import tqdm
import torch
from torch import nn
from torch import optim

class LP_optimizer:
    """Run training epochs that mix a relation loss with an attribute loss.

    The combined objective is ``alpha * attr_loss + (1 - alpha) * rel_loss``
    plus an L2 penalty on the model's embedding tables whose weights are
    hard-coded in ``__init__``.

    Args:
        model: Object exposing ``score(data_entity, data_att)`` that returns
            ``(rel_scores, attr_scores)``, and embedding attributes
            ``ent_embeddings`` / ``rel_embeddings`` (optionally
            ``attr_embeddings`` / ``offset_attr_embeddings``), each with a
            ``.weight`` tensor.
        rel_loss: Object exposing ``compute(rel_scores)``.
        att_loss: Object exposing ``compute(attr_scores)``.
        regularizer: Stored on the instance; not used by the code in this
            class.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        batch_size: Stored on the instance; not used by the code in this
            class.
        alpha: Weight of the attribute loss; the relation loss gets
            ``1 - alpha``.
    """

    def __init__(self,
                 model,
                 rel_loss,
                 att_loss,
                 regularizer,
                 optimizer,
                 batch_size,
                 alpha
                 ):
        self.model = model
        self.rel_loss = rel_loss
        self.att_loss = att_loss
        self.regularizer = regularizer
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.alpha = alpha
        # L2 penalty weights read by _custom_l2_regularization; a weight
        # that is not > 0 disables the corresponding term.
        self.reg_weight_ent = 1.3e-07
        self.reg_weight_rel = 3.7e-18
        self.reg_weight_attr = 0

    @staticmethod
    def _l2_term(weight, embedding):
        """Return ``weight * ||embedding.weight||_2 ** 2`` as a scalar tensor."""
        return weight * embedding.weight.norm(p=2) ** 2

    def _custom_l2_regularization(self):
        """Back-propagate the weighted squared L2 norms of the embeddings.

        Gradients are accumulated into the embedding weights' ``.grad``;
        nothing is returned. If every regularization weight is disabled,
        this is a no-op.
        """
        terms = []
        if self.reg_weight_ent > 0:
            terms.append(
                self._l2_term(self.reg_weight_ent, self.model.ent_embeddings))
        if self.reg_weight_rel > 0:
            terms.append(
                self._l2_term(self.reg_weight_rel, self.model.rel_embeddings))
        if self.reg_weight_attr > 0 and hasattr(self.model, "attr_embeddings"):
            terms.append(
                self._l2_term(self.reg_weight_attr, self.model.attr_embeddings))
            if hasattr(self.model, "offset_attr_embeddings"):
                terms.append(
                    self._l2_term(self.reg_weight_attr,
                                  self.model.offset_attr_embeddings))

        # Decide on *presence* of terms, not on the tensor's value: a penalty
        # that happens to be exactly 0.0 (e.g. after underflow) must still be
        # part of the graph, and tensor truthiness forces a device sync.
        if not terms:
            return
        penalty = terms[0]
        for term in terms[1:]:
            penalty = penalty + term
        penalty.backward()

    def _train_step(self, data_entity, data_att):
        """Run one optimization step on a pair of batches.

        Args:
            data_entity: Batch passed as first argument to ``model.score``.
            data_att: Batch passed as second argument to ``model.score``.

        Returns:
            The combined (un-regularized) loss tensor for this batch.
        """
        # Clear gradients left over from the previous batch; without this
        # every step would use the sum of all gradients seen so far.
        self.optimizer.zero_grad()

        rel_scores, attr_scores = self.model.score(data_entity, data_att)
        rel_loss = self.rel_loss.compute(rel_scores)
        # NaN attribute losses are replaced by 0 so they do not poison the
        # combined loss.
        attr_loss = self.att_loss.compute(attr_scores).nan_to_num()
        loss = self.alpha * attr_loss + (1 - self.alpha) * rel_loss

        loss.backward()
        self._custom_l2_regularization()
        # NOTE(review): the original code had a placeholder comment about
        # clipping; if gradient clipping is wanted it would go here,
        # between the backward passes and the optimizer step.
        self.optimizer.step()
        return loss

    def train_epoch(self, epoch, entity_train_dataloader, att_train_dataloader):
        """Train for one pass over the two dataloaders, showing a progress bar.

        The loaders are iterated in lockstep with ``zip``, so the epoch ends
        when the shorter one is exhausted.

        Args:
            epoch: Epoch identifier, used only in the progress-bar label.
            entity_train_dataloader: Loader yielding entity batches. Its
                ``dataset.len`` attribute sets the progress-bar total and its
                ``batch_size`` is the per-step progress increment.
            att_train_dataloader: Loader yielding attribute batches.
        """
        with tqdm.tqdm(total=entity_train_dataloader.dataset.len,
                       unit='ex', disable=False) as bar:
            bar.set_description('train loss--epoch{}'.format(epoch))
            for data_entity, data_att in zip(entity_train_dataloader,
                                             att_train_dataloader):
                loss = self._train_step(data_entity, data_att)

                bar.update(entity_train_dataloader.batch_size)
                bar.set_postfix(loss=f'{loss.item():.3f}')